import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import os


# Make sure the folder that receives the rendered figures exists.
os.makedirs('./Visuals', exist_ok=True)

# Read the paired input/output arrays of a single sample from disk.
input_data = np.load(
    './Train Val Test/SF_1.0m/000100000100110010000001010010_1.0m.input_Main.npy'
)
output_data = np.load(
    './Train Val Test/SF_1.0m/000100000100110010000001010010_1.0m.output_Main.npy'
)

# Boolean masks flagging the strictly positive cells of each array.
input_mask, output_mask = input_data > 0, output_data > 0

### PLOTTING INPUT ###

# Square canvas for the input array.
plt.figure(figsize=(12, 12))

# (row, col) index pairs of the cells holding each of the two plotted values.
input_value1_coords = np.column_stack(np.nonzero(input_data == 1))
input_value2_coords = np.column_stack(np.nonzero(input_data == 0.5))

# One scatter layer per class: column index on x, row index on y.
for class_coords, class_color, class_label in (
    (input_value1_coords, '#adb5bd', 'Unprotected'),  # gray: cells equal to 1
    (input_value2_coords, 'black', 'Protected'),      # black: cells equal to 0.5
):
    plt.scatter(
        class_coords[:, 1],
        class_coords[:, 0],
        c=class_color,
        s=5,
        label=class_label,
    )

plt.legend(loc='upper left', fontsize=16, markerscale=4, frameon=True)

# Hide the axes decorations, then pin the view to the full array extent.
input_rows, input_cols = input_data.shape[0], input_data.shape[1]
plt.axis('off')
plt.xlim(0, input_cols)
plt.ylim(0, input_rows)

# Write the figure to disk before opening the interactive window.
plt.savefig('./Visuals/Input.png', dpi=500, bbox_inches='tight', pad_inches=0)
plt.show()



### PLOTTING OUTPUT ###

# Four-stop gradient (green-yellow -> gold -> dark orange -> red) used to
# color-code the plotted values.
colors = ['#ADFF2F', '#FFD700', '#FF8C00', '#FF0000']
cmap_name = 'custom_cmap'
custom_cmap = LinearSegmentedColormap.from_list(cmap_name, colors)

plt.figure(figsize=(15, 12))

# Sample the output at every cell selected by input_mask (not output_mask), so
# cells whose output value is exactly zero are still drawn.
output_coords = np.column_stack(np.nonzero(input_mask))
pwl_values = output_data[output_coords[:, 0], output_coords[:, 1]]

pwl = pwl_values

zero_mask = pwl == 0
non_zero_mask = pwl != 0

# The color scale runs from 0 up to the largest non-zero value. When nothing is
# non-zero (every sampled value is zero, or no cell was selected at all) use a
# degenerate 0..0 scale; calling min()/max() here would raise ValueError on an
# empty selection.
if np.any(non_zero_mask):
    norm_pwl = plt.Normalize(vmin=0, vmax=pwl[non_zero_mask].max())
else:
    norm_pwl = plt.Normalize(vmin=0, vmax=0)

# Column index on x, row index on y.
x = output_coords[:, 1]
y = output_coords[:, 0]

# Both layers share the same colormap and normalization. The non-zero layer is
# added second, so it is rendered after the zero-valued layer.
plt.scatter(x[zero_mask], y[zero_mask], c=pwl[zero_mask], cmap=custom_cmap, norm=norm_pwl, s=5)
pwl_scatter = plt.scatter(
    x[non_zero_mask],
    y[non_zero_mask],
    c=pwl[non_zero_mask],
    cmap=custom_cmap,
    norm=norm_pwl,
    s=5,
)

# Attach the colorbar to the scatter explicitly instead of relying on pyplot's
# implicit "current image".
cbar = plt.colorbar(pwl_scatter)
cbar.set_label('PWL Intensity (meters above sea level)', fontsize=18, color='black')
cbar.ax.tick_params(labelsize=18, colors='black')

# Hide the main axes and pin the view to the extent of input_data, the array
# whose mask produced the plotted coordinates.
plt.axis('off')
plt.xlim(0, input_data.shape[1])
plt.ylim(0, input_data.shape[0])

plt.savefig('./Visuals/Output.png', dpi=500, bbox_inches='tight')

plt.show()
